import torch
from typing import List, Optional, Tuple, Union
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

        
class DDESKVCache_LayerWise:
    """Heavy-hitter + recent-window eviction policy for one layer's KV cache.

    After pruning, a layer keeps the ``hh_size`` older tokens with the highest
    accumulated attention score (per head) plus the newest ``recent_size``
    tokens, i.e. at most ``cache_size = hh_size + recent_size`` positions.

    NOTE(review): ``past_key_values`` is expected to support
    ``past_key_values[layer_index] -> (key, value)`` and to expose mutable
    ``key_cache`` / ``value_cache`` lists (as older ``transformers``
    ``DynamicCache`` does) — confirm against the caller. Key/value tensors are
    laid out as ``(bsz, num_heads, seq_len, head_dim)``; selection is done
    along dim 2.
    """

    def __init__(
        self,
        hh_size=1024,
        recent_size=128,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim  # dim used to read the key sequence length
        self.v_seq_dim = v_seq_dim
        # Running per-(head, position) attention totals; None until the first update.
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache, layer_index):
        """Accumulate new attention scores, then prune the layer to ``cache_size``.

        Args:
            past_key_values: cache holding this layer's key/value tensors, or None.
            attn_score_cache: attention scores of the latest step; summed over
                dims 0 and 2 it must yield one score per (head, key position).
                Presumably ``(bsz, num_heads, q_len, k_len)`` — confirm.
            layer_index: layer of ``past_key_values`` to prune.

        Returns:
            ``past_key_values`` (its ``layer_index`` entries replaced when pruning
            happened), or None when ``past_key_values`` is None.
        """
        self._update_hh_score(attn_score_cache)
        if past_key_values is None:
            return None
        seq_len = past_key_values[layer_index][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        return self._evict(past_key_values, layer_index, seq_len, 0)

    def evict_for_space(self, past_key_values, num_coming, layer_index):
        """Prune the layer so that ``num_coming`` new tokens will fit in ``cache_size``.

        The recent window is shortened by ``num_coming`` positions; heavy hitters
        are still capped at ``hh_size``.

        Returns:
            ``past_key_values`` (pruned in place when needed), or None when it is None.
        """
        if past_key_values is None:
            return None
        # Measure the layer actually being pruned (previously always layer 0).
        seq_len = past_key_values[layer_index][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        return self._evict(past_key_values, layer_index, seq_len, num_coming)

    def _evict(self, past_key_values, layer_index, seq_len, num_coming):
        """Keep top-``hh_size`` older tokens per head plus the trailing window."""
        key_states, value_states = past_key_values[layer_index]
        recent_start = seq_len - self.recent_size + num_coming

        # Heavy hitters are chosen only among positions before the recent window.
        _, keep_topk = torch.topk(self.hh_score[:, :recent_start], self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values
        keep_recent = torch.arange(recent_start, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        # Both parts are ascending and disjoint, so each row of keep_idx is ascending
        # and surviving tokens keep their original relative order.
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        past_key_values.key_cache[layer_index] = self._gather_tokens(key_states, keep_idx)
        past_key_values.value_cache[layer_index] = self._gather_tokens(value_states, keep_idx)
        # NOTE(review): assumes hh_score has as many heads as the cached keys/values
        # (no grouped-query attention) — confirm.
        self.hh_score = self.hh_score.gather(-1, keep_idx)
        return past_key_values

    @staticmethod
    def _gather_tokens(states, keep_idx):
        """Select positions ``keep_idx`` (num_heads, K) along dim 2 of ``states``.

        The same positions are kept for every batch element, because
        ``hh_score`` is summed over dim 0 of the attention scores.
        """
        bsz, _, _, head_dim = states.shape
        index = keep_idx.to(states.device)[None, :, :, None].expand(bsz, -1, -1, head_dim)
        return states.gather(2, index)

    def _update_hh_score(self, attn_score_cache):
        """Fold one step of attention scores into the running per-position totals.

        ``attn_score_cache`` is summed over dims 0 and 2. Dim 2 is treated as the
        number of new tokens: the summed tensor must be exactly that many
        positions longer than the stored ``hh_score``.
        """
        num_new_tokens = attn_score_cache.shape[2]
        scores = attn_score_cache.sum(0).sum(1)
        if self.hh_score is not None:
            # Existing positions accumulate; the trailing new positions start fresh.
            scores[:, :-num_new_tokens] += self.hh_score
        self.hh_score = scores

    def _clean_scores(self):
        """Forget all accumulated scores."""
        self.hh_score = None


